!--------------------------------------------------------------------------------------------------!
!   CP2K: A general program to perform molecular dynamics simulations                              !
!   Copyright 2000-2024 CP2K developers group <https://cp2k.org>                                   !
!                                                                                                  !
!   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
!--------------------------------------------------------------------------------------------------!

! **************************************************************************************************
!> \brief Interface to Maxwell equation solver
!> \par History
!>      11/2020 created [mbrehm]
!> \author Martin Brehm
! **************************************************************************************************
MODULE maxwell_solver_interface
   USE cp_control_types, ONLY: maxwell_control_type
   USE cp_log_handling, ONLY: cp_get_default_logger, &
                              cp_logger_get_default_io_unit, &
                              cp_logger_type
   USE kinds, ONLY: dp
   USE pw_types, ONLY: pw_r3d_rs_type
   USE ISO_C_BINDING, ONLY: C_INT, C_DOUBLE

#include "./base/base_uses.f90"

   IMPLICIT NONE

   PRIVATE

   CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'maxwell_solver_interface'

! *** Public subroutines ***
   PUBLIC :: maxwell_solver

#if defined(__LIBMAXWELL)

   ! Explicit interfaces to the C entry points of the external Maxwell library.
   ! None of the dummy arguments has the VALUE attribute, i.e. everything is passed by reference.
   ! NOTE(review): the meaning of the integer return codes is not visible from this file.
   INTERFACE

      ! Grid setup. maxwell_solver passes the number of grid points per direction (rx, ry, rz)
      ! and the columns of pw_grid%dh scaled by these numbers as (ax..az), (bx..bz), (cx..cz)
      ! - presumably the three cell vectors; confirm against the library documentation.
      INTEGER(C_INT) FUNCTION libcp2kmw_setgrid(rx, ry, rz, ax, ay, az, bx, by, bz, cx, cy, cz) BIND(C, NAME='libcp2kmw_setgrid')
         USE ISO_C_BINDING, ONLY: C_INT, C_DOUBLE
         IMPLICIT NONE
         INTEGER(C_INT) :: rx, ry, rz
         REAL(C_DOUBLE) :: ax, ay, az, bx, by, bz, cx, cy, cz
      END FUNCTION libcp2kmw_setgrid

      ! Solver step. maxwell_solver passes the current simulation step and simulation time.
      INTEGER(C_INT) FUNCTION libcp2kmw_step(step, t) BIND(C, NAME='libcp2kmw_step')
         USE ISO_C_BINDING, ONLY: C_INT, C_DOUBLE
         IMPLICIT NONE
         INTEGER(C_INT) :: step
         REAL(C_DOUBLE) :: t
      END FUNCTION libcp2kmw_step

      ! Row retrieval. maxwell_solver passes a buffer spanning the z range of the grid, the
      ! (x, y) grid indices of the row and the lower/upper z index; the library is expected to
      ! fill buf with the potential along that row (inferred from the name and the call site).
      INTEGER(C_INT) FUNCTION libcp2kmw_getzrow(buf, px, py, zmin, zmax) BIND(C, NAME='libcp2kmw_getzrow')
         USE ISO_C_BINDING, ONLY: C_INT, C_DOUBLE
         IMPLICIT NONE
         REAL(C_DOUBLE) :: buf(*)
         INTEGER(C_INT) :: px, py, zmin, zmax
      END FUNCTION libcp2kmw_getzrow

   END INTERFACE

#endif

CONTAINS

! **************************************************************************************************
!> \brief  Computes the external potential on the grid
!> \param maxwell_control the Maxwell control section
!> \param v_ee the realspace grid with the potential
!> \param sim_step current simulation step
!> \param sim_time current physical simulation time
!> \param scaling_factor a factor to scale the potential with
!> \date   12/2020
!> \author Martin Brehm
!> \note   Only rank 0 of the grid communicator calls into the Maxwell library. The potential is
!>         fetched from the library one z-row at a time and broadcast to all ranks; each rank
!>         stores the rows that fall inside its local x/y bounds. Without -D__LIBMAXWELL the
!>         routine aborts.
! **************************************************************************************************
   SUBROUTINE maxwell_solver(maxwell_control, v_ee, sim_step, sim_time, scaling_factor)
      TYPE(maxwell_control_type), INTENT(IN)             :: maxwell_control
      TYPE(pw_r3d_rs_type), POINTER                           :: v_ee
      INTEGER, INTENT(IN)                                :: sim_step
      REAL(KIND=dp), INTENT(IN)                          :: sim_time
      REAL(KIND=dp), INTENT(IN)                          :: scaling_factor

#if defined(__LIBMAXWELL)

      CHARACTER(len=*), PARAMETER                        :: routineN = 'maxwell_solver'
      ! rank that talks to the Maxwell library and acts as broadcast root
      INTEGER, PARAMETER                                 :: source_rank = 0

      INTEGER                                            :: handle, i, iounit, j, k, my_rank, res
      INTEGER, DIMENSION(3)                              :: extent, lbounds, lbounds_local, &
                                                            ubounds, ubounds_local
      REAL(C_DOUBLE), ALLOCATABLE, DIMENSION(:)          :: buffer
      REAL(C_DOUBLE), DIMENSION(3, 3)                    :: cell
      TYPE(cp_logger_type), POINTER                      :: logger

      ! maxwell_control is currently not evaluated in this routine
      MARK_USED(maxwell_control)

      CALL timeset(routineN, handle)
      NULLIFY (logger)
      logger => cp_get_default_logger()
      iounit = cp_logger_get_default_io_unit(logger)

      my_rank = v_ee%pw_grid%para%group%mepos

      ! global index range of the grid and number of points per direction
      lbounds = v_ee%pw_grid%bounds(1, :)
      ubounds = v_ee%pw_grid%bounds(2, :)
      extent = ubounds - lbounds + 1

      ! index range owned by this rank
      lbounds_local = v_ee%pw_grid%bounds_local(1, :)
      ubounds_local = v_ee%pw_grid%bounds_local(2, :)

      ! one row of potential values along z
      ALLOCATE (buffer(lbounds(3):ubounds(3)))

      IF (my_rank == source_rank) THEN

         IF (iounit > 0) THEN
            WRITE (iounit, *) ""
            WRITE (iounit, *) "MAXWELL| Called, step = ", sim_step, " time = ", sim_time
         END IF

         ! column k of dh scaled by the number of grid points along direction k
         DO k = 1, 3
            cell(:, k) = v_ee%pw_grid%dh(:, k)*extent(k)
         END DO

         ! NOTE(review): the return value of libcp2kmw_setgrid is overwritten below without being
         ! inspected; add a check once the return code convention of the library is confirmed.
         res = libcp2kmw_setgrid( &
               extent(1), extent(2), extent(3), &
               cell(1, 1), cell(2, 1), cell(3, 1), &
               cell(1, 2), cell(2, 2), cell(3, 2), &
               cell(1, 3), cell(2, 3), cell(3, 3) &
               )

         res = libcp2kmw_step(sim_step, sim_time)

         IF (iounit > 0) THEN
            WRITE (iounit, *) "MAXWELL| Returned with value ", res
            WRITE (iounit, *) "MAXWELL| Distributing potential to MPI processes..."
         END IF

      END IF

      ! The distribution scheme follows src/pw/realspace_grid_cube.F:
      ! the source rank obtains one z-row at a time and sends it to everyone.
      ! All ranks have to traverse the full global x/y range because the broadcast is collective.
      DO i = lbounds(1), ubounds(1)
         DO j = lbounds(2), ubounds(2)

            ! fetch the potential along z at grid position (i, j) from the library;
            ! without this call the buffer would be broadcast while still undefined
            IF (my_rank == source_rank) THEN
               res = libcp2kmw_getzrow(buffer, i, j, lbounds(3), ubounds(3))
            END IF

            ! the broadcast is issued on the communicator of the grid itself
            CALL v_ee%pw_grid%para%group%bcast(buffer, source_rank)

            !only use data that is local to me - i.e. in slice of pencil I own
            IF ((lbounds_local(1) <= i) .AND. (i <= ubounds_local(1)) .AND. &
                (lbounds_local(2) <= j) .AND. (j <= ubounds_local(2))) THEN
               !allow scaling of external potential values by factor 'scaling' (SCALING_FACTOR in input file)
               v_ee%array(i, j, lbounds(3):ubounds(3)) = buffer(lbounds(3):ubounds(3))*scaling_factor
            END IF

         END DO
      END DO

      IF (iounit > 0) THEN
         WRITE (iounit, *) "MAXWELL| All done."
      END IF

      CALL timestop(handle)

#else

      MARK_USED(maxwell_control)
      MARK_USED(v_ee)
      MARK_USED(sim_step)
      MARK_USED(sim_time)
      MARK_USED(scaling_factor)

      CALL cp_abort(__LOCATION__, &
                    "The Maxwell solver interface requires CP2k to be compiled &
                     &with the -D__LIBMAXWELL preprocessor option.")

#endif

   END SUBROUTINE maxwell_solver

END MODULE maxwell_solver_interface

